#!/usr/bin/env python3

# Copyright (c) 2023, Arm Limited and Contributors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
from enum import Enum
from functools import partial
import logging  # noqa: F401
import os
import pytest
import socket
from threading import Thread
import time
from typing import Any, Coroutine, Optional

from pw_hdlc.rpc import HdlcRpcClient, default_channels
from pw_status import Status as PwStatus

from pyedmgr import (  # noqa: F401
    AbstractChannel,
    SynchronousSocketChannel,
    TestCaseContext,
    TestDevice,
    fixture_test_case,  # noqa: F811
)

# Protocol definition handed to HdlcRpcClient (presumably the path of the
# IotSocketService .proto file). It must be provided through the
# environment; a missing variable raises KeyError at import time.
PROTO: str = os.environ["PROTO"]

# Number of bytes the HDLC RPC client reads from the channel per read call.
RPC_READ_SIZE: int = 1

# Device terminal indices: the log terminal is streamed by the background
# logger threads, the RPC terminal carries the HDLC-framed RPC traffic.
LOG_TERMINAL: int = 0
RPC_TERMINAL: int = 1


class AddressFamily(Enum):
    """IoT Socket address family.

    The member value is what is sent over the IotSocketService RPC (for
    example as the ``af`` field of Create).
    """

    UNSPECIFIED = 0
    INET = 1
    INET6 = 2

    def to_af_inet(self):
        """Return the equivalent BSD ``socket`` module address family.

        Any member without a BSD counterpart (UNSPECIFIED) yields AF_UNSPEC.
        """
        bsd_family = {
            AddressFamily.INET: socket.AF_INET,
            AddressFamily.INET6: socket.AF_INET6,
        }
        return bsd_family.get(self, socket.AF_UNSPEC)


class SocketType(Enum):
    """Wrapper for IoT Socket socket type.

    The member value is sent as the ``type`` field of the IotSocketService
    Create RPC. NOTE(review): values presumably mirror the device-side IoT
    Socket definitions — confirm against the .proto before changing them.
    """

    UNSPECIFIED = 0
    STREAM = 1
    DGRAM = 2


class ProtocolType(Enum):
    """Wrapper for IoT Socket protocol type.

    The member value is sent as the ``protocol`` field of the
    IotSocketService Create RPC. NOTE(review): values presumably mirror the
    device-side IoT Socket definitions — confirm against the .proto.
    """

    UNSPECIFIED = 0
    TCP = 1
    UDP = 2


class SocketResult(Enum):
    """Return value from IoT Socket.

    Negative error codes carried in the ``status`` field of IotSocketService
    responses. raise_if_error() treats any negative status as an error and
    looks it up in this enum; non-negative statuses are not errors and are
    not represented here.

    NOTE(review): codes presumably mirror the IoT Socket C API error
    constants (-1 .. -16) — confirm against the device headers.
    """

    ERROR = -1
    ESOCK = -2
    EINVAL = -3
    ENOTSUP = -4
    ENOMEM = -5
    EAGAIN = -6
    EINPROGRESS = -7
    ETIMEDOUT = -8
    EISCONN = -9
    ENOTCONN = -10
    ECONNREFUSED = -11
    ECONNRESET = -12
    ECONNABORTED = -13
    EALREADY = -14
    EADDRINUSE = -15
    EHOSTNOTFOUND = -16


class PwRpcError(RuntimeError):
    """Raised when the Pigweed RPC call itself does not return an OK status.

    The message reads ``"<fn_name>: <status name>"``.
    """

    def __init__(self, fn_name: str, pw_status: PwStatus):
        message = f"{fn_name}: {pw_status.name}"
        super().__init__(message)


class IotSocketError(RuntimeError):
    """Raised when the IoT Socket layer reports a negative (error) status.

    The message reads ``"<fn_name>: <code name> (<numeric code>)"``.
    """

    def __init__(self, fn_name: str, iot_status: SocketResult):
        code_name, code_value = iot_status.name, iot_status.value
        super().__init__(f"{fn_name}: {code_name} ({code_value})")


def raise_if_error(fn_name: str, pw_status: PwStatus, iot_status: int):
    """Raise if an IotSocketService call failed at either layer.

    Args:
        fn_name: Name of the wrapper operation; used as the message prefix.
        pw_status: Status of the Pigweed RPC transport.
        iot_status: The ``status`` field of the RPC response; negative values
            are IoT Socket error codes.

    Raises:
        PwRpcError: If the RPC itself did not complete OK (checked first).
        IotSocketError: If ``iot_status`` is negative. A code that has no
            SocketResult member is reported as SocketResult.ERROR, with the
            original lookup failure (which names the code) chained as the
            cause, instead of escaping as a bare ValueError that callers
            catching IotSocketError would not expect.
    """
    if not pw_status.ok():
        raise PwRpcError(fn_name, pw_status)
    if iot_status >= 0:
        return
    try:
        result = SocketResult(iot_status)
    except ValueError as err:
        raise IotSocketError(fn_name, SocketResult.ERROR) from err
    raise IotSocketError(fn_name, result)


class AcceptResult:
    """Details of a connection returned by IotSocketClient.accept()."""

    def __init__(self, client_ip, client_port, client_socket):
        # client_ip: peer address as text; client_port: peer port;
        # client_socket: device-side handle of the accepted socket.
        (self.client_ip, self.client_port, self.client_socket) = (
            client_ip,
            client_port,
            client_socket,
        )


class IotSocketClient:
    """Wrapper for the IotSocketService RPC service.

    It can act as a server or a client. When acting as server it can only
    accept a single client: the socket returned by accept() is remembered and
    automatically used by send() and recv().

    Every operation raises PwRpcError when the RPC transport fails and
    IotSocketError when the device reports a negative IoT Socket status
    (both via raise_if_error).
    """

    @classmethod
    def from_test_device(cls, device: TestDevice, logger: logging.Logger):
        """Construct from a TestDevice by opening a sync channel on its RPC
        terminal. The channel is not closed by this class."""
        channel = device.controller.get_channel(terminal=RPC_TERMINAL, sync=True)
        channel.open()
        return cls(channel, logger)

    def __init__(self, channel: SynchronousSocketChannel, logger: logging.Logger):
        """Create an HDLC RPC client over ``channel`` and bind the service.

        Args:
            channel: Opened channel used for RPC reads and writes.
            logger: Logger used for per-operation debug messages.
        """
        self._logger = logger

        client = HdlcRpcClient(
            partial(channel.read, RPC_READ_SIZE),
            [PROTO],
            default_channels(channel.write),
        )

        # Wait for proto compilation
        # NOTE(review): fixed delay kept from the original; confirm it is
        # still required.
        time.sleep(1)

        self._service = client.rpcs().iotsdk.socket.pw_rpc.IotSocketService
        # Device handle from create(); None before create() and after close().
        self.socket: Optional[int] = None
        # Family passed to create(); used to pack/unpack IP addresses.
        self.af: Optional[AddressFamily] = None

        # When self.socket is a bound socket, self._accepted_socket is the
        # socket returned by accept().
        self._accepted_socket: Optional[int] = None

    def _active_socket(self) -> int:
        """Return the accepted socket if there is one, else the main socket."""
        if self._accepted_socket is not None:
            sock = self._accepted_socket
        else:
            sock = self.socket
        assert sock is not None
        return sock

    def get_ip(self, af: AddressFamily) -> str:
        """Return the device's local IP address as text, decoded using ``af``."""
        status, response = self._service.GetLocalIp()
        raise_if_error("get_ip", status, response.status)
        return socket.inet_ntop(af.to_af_inet(), response.ip[: response.ip_len])

    def create(
        self, af: AddressFamily, type_: SocketType, protocol: ProtocolType
    ) -> int:
        """Create a socket on the device and remember it as the main socket.

        Returns:
            The socket handle (the response ``status``).
        """
        status, response = self._service.Create(
            af=af.value, type=type_.value, protocol=protocol.value
        )
        raise_if_error("create", status, response.status)
        self.socket = response.status
        self.af = af
        return self.socket

    def bind(self, host: str, port: int):
        """Bind the main socket to ``host``:``port`` (``host`` in text form)."""
        assert self.socket is not None

        host_bytes = socket.inet_pton(self.af.to_af_inet(), host)

        status, response = self._service.Bind(
            socket=self.socket, ip=host_bytes, ip_len=len(host_bytes), port=port
        )
        raise_if_error("bind", status, response.status)
        self._logger.debug(f"Bound to {host}:{port}")

    def listen(self):
        """Put the main socket into listening mode with a backlog of 1."""
        assert self.socket is not None

        status, response = self._service.Listen(socket=self.socket, backlog=1)
        raise_if_error("listen", status, response.status)
        self._logger.debug("Listening")

    def accept(self) -> AcceptResult:
        """Accept a connection on the main socket.

        The returned socket becomes the one used by send() and recv().
        """
        assert self.socket is not None

        status, response = self._service.Accept(socket=self.socket)
        raise_if_error("accept", status, response.status)

        ip_str = socket.inet_ntop(self.af.to_af_inet(), response.ip[: response.ip_len])
        self._accepted_socket = response.status
        self._logger.debug(
            f"Accepted connection from {ip_str}:{response.port} on "
            + f"socket {self._accepted_socket}"
        )
        return AcceptResult(ip_str, response.port, self._accepted_socket)

    def connect(self, host: str, port: int) -> int:
        """Connect the main socket to ``host``:``port``.

        Returns:
            The non-negative ``status`` field of the Connect response.
        """
        assert self.socket is not None

        host_bytes = socket.inet_pton(self.af.to_af_inet(), host)

        status, response = self._service.Connect(
            socket=self.socket, ip=host_bytes, ip_len=len(host_bytes), port=port
        )
        raise_if_error("connect", status, response.status)

        self._logger.debug(f"Connected to {host}:{port} on socket {response.status}")
        return response.status

    def send(self, buf: bytes):
        """Send ``buf`` on the active socket (accepted one if any, else main).

        NOTE(review): the response ``status`` (presumably the byte count on
        success) is not compared with ``len(buf)``, so a partial send is not
        detected here.
        """
        sock = self._active_socket()

        status, response = self._service.Send(socket=sock, buf=buf, len=len(buf))
        raise_if_error("send", status, response.status)

        self._logger.debug(f"Sent {len(buf)} bytes")

    def recv(self, n: int) -> bytes:
        """Request ``n`` bytes from the active socket and return the buffer.

        The returned buffer may presumably be shorter than ``n``; callers
        should not assume a full read.
        """
        sock = self._active_socket()

        status, response = self._service.Recv(socket=sock, len=n)
        raise_if_error("recv", status, response.status)

        self._logger.debug(f"Received {len(response.buf)} bytes")
        return response.buf

    def close(self):
        """Close the accepted socket (if any), then the main socket.

        The main socket close is attempted even when closing the accepted
        socket fails, so one failure does not leak the other handle on the
        device. Each handle is cleared only after its close succeeded.
        """
        try:
            # Close accepted socket if any
            if self._accepted_socket is not None:
                status, response = self._service.Close(socket=self._accepted_socket)
                raise_if_error("close", status, response.status)
                self._accepted_socket = None
        finally:
            # Close main socket which must be open
            assert self.socket is not None
            status, response = self._service.Close(socket=self.socket)
            raise_if_error("close", status, response.status)
            self.socket = None

        self._logger.debug("Socket(s) closed")


async def accept_async(device: IotSocketClient) -> AcceptResult:
    """Run ``device.accept()`` in a worker thread and return its result.

    The blocking accept RPC is offloaded with ``asyncio.to_thread`` so the
    event loop stays free while waiting for a client.

    Calling this function only creates a coroutine; nothing happens until it
    is awaited or scheduled. Wrap it in ``asyncio.create_task`` to have
    accept() run concurrently with other work (e.g. the peer's connect()).
    """
    return await asyncio.to_thread(device.accept)


def log_fvp_output_in_background(server: TestDevice, client: TestDevice):
    """Log FVP output in background daemon threads, one per device.

    Each device's console (LOG_TERMINAL) is forwarded line by line to the
    logger named "client" or "server".
    """

    def run_for_device(device: TestDevice, name: str):
        # NOTE(review): presumably releases the device's default channel
        # before opening the log terminal — confirm with pyedmgr.
        device.channel.close()
        channel = device.controller.get_channel(terminal=LOG_TERMINAL, sync=True)
        channel.open()
        logger = logging.getLogger(name)
        while True:
            raw_line = channel.readline()
            # errors="replace": a stray non-UTF-8 byte in the console output
            # must not raise UnicodeDecodeError and kill this logging thread.
            text = raw_line.decode("utf-8", errors="replace")
            # Drop the line terminator so each record is a single log line.
            logger.info(text.rstrip("\r\n"))

    Thread(target=run_for_device, args=(client, "client"), daemon=True).start()
    Thread(target=run_for_device, args=(server, "server"), daemon=True).start()


def get_server_ip(
    server: IotSocketClient,
    *,
    poll_interval: float = 0.1,
    timeout: Optional[float] = None,
) -> str:
    """Retrieve the IPv6 address configured by the device acting as server.

    Polls GetLocalIp until the device stops reporting an IoT Socket error
    (i.e. its network is up).

    Args:
        server: Client wrapper connected to the server device.
        poll_interval: Seconds to wait between failed attempts, so the device
            is not flooded with back-to-back RPCs.
        timeout: Give up after this many seconds. None (the default) waits
            indefinitely.

    Returns:
        The IPv6 address in text form.

    Raises:
        TimeoutError: If ``timeout`` elapses before an address is available.
        PwRpcError: If the RPC transport itself fails (this is not retried).
    """
    logging.info("Waiting for server network to be ready")
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        try:
            ip = server.get_ip(AddressFamily.INET6)
        except IotSocketError as err:
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Server IP not available after {timeout} s"
                ) from err
            time.sleep(poll_interval)
        else:
            break
    logging.info(f"Server IP set to {ip}")
    return ip


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "fixture_test_case", ["@test.json"], indirect=["fixture_test_case"]
)
async def test__socket_client_and_server(
    fixture_test_case: TestCaseContext,  # noqa: F811
):
    """TCP over IPv6 between two devices driven through IoT Socket RPC.

    Device 0 is the server (bind/listen/accept) and device 1 the client
    (connect). Each side sends one message and the other checks it.
    """
    async with fixture_test_case as context:
        assert len(context.allocated_devices) == 2

        bind_address = "::"
        port = 9001
        msg_server2client = b"hello client"
        msg_client2server = b"hello server"

        server_device = context.allocated_devices[0]
        client_device = context.allocated_devices[1]

        # Log FVP output
        log_fvp_output_in_background(server_device, client_device)

        # Create server socket
        server: IotSocketClient = IotSocketClient.from_test_device(
            server_device, logging.getLogger("server")
        )
        server.create(AddressFamily.INET6, SocketType.STREAM, ProtocolType.TCP)
        server_ip = get_server_ip(server)

        # Create client socket
        client = IotSocketClient.from_test_device(
            client_device, logging.getLogger("client")
        )
        client.create(AddressFamily.INET6, SocketType.STREAM, ProtocolType.TCP)

        # Server binds to :: and waits for connection in a worker thread.
        # A bare coroutine does not run until awaited, so schedule it as a
        # task to get accept() in flight before the client connects.
        server.bind(bind_address, port)
        server.listen()
        accept_task = asyncio.create_task(accept_async(server))

        # Allow server time to be ready. asyncio.sleep (not time.sleep)
        # yields to the event loop so the accept task actually starts.
        await asyncio.sleep(1)

        # Client connects to the server while accept() runs in the worker
        # thread; this call blocks the event loop until connect returns.
        client.connect(server_ip, port)
        await accept_task

        # Server accepts the connection and sends its message
        server.send(msg_server2client)

        # Client receives the message
        res = client.recv(len(msg_server2client))
        assert res == msg_server2client

        # Client sends a message back
        client.send(msg_client2server)

        # Server receives the message
        res = server.recv(len(msg_client2server))
        assert res == msg_client2server
